from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm import tqdm
import json
import torch
import random
import numpy as np
import argparse

# Reproducibility: seed the torch, stdlib `random`, and NumPy generators with
# one shared value so sampled generations can be repeated across runs.
seed = 42
for seeder in (torch.manual_seed, random.seed, np.random.seed):
    seeder(seed)

# Short CLI alias -> (Hugging Face checkpoint id, whether to load quantized).
# Only the two largest checkpoints (70B / 72B) are flagged for quantization.
_MODEL_TABLE = {
    "llama8b": ("meta-llama/Llama-3.1-8B-Instruct", False),
    "llama70b": ("meta-llama/Llama-3.3-70B-Instruct", True),
    "qwen7b": ("Qwen/Qwen2.5-7B-Instruct", False),
    "qwen72b": ("Qwen/Qwen2.5-72B-Instruct", True),
    "olmo7b": ("allenai/OLMo-2-1124-7B-Instruct", False),
    "olmo32b": ("allenai/OLMo-2-0325-32B-Instruct", False),
}

# Command line: a single required --model_name restricted to the aliases above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_name",
    type=str,
    required=True,
    choices=list(_MODEL_TABLE),
    help="Model to use",
)
args = parser.parse_args()
model_name = args.model_name

# argparse already rejects unknown aliases; the explicit error is kept as a
# safeguard in case the choices list and the table ever diverge.
if model_name not in _MODEL_TABLE:
    raise ValueError("Model not supported")
model_path, quant = _MODEL_TABLE[model_name]

# Load the AIME dataset from the working directory. The generation loop below
# indexes `data` positionally and reads each entry's "Problem" field.
# JSON text is UTF-8, so decode it explicitly rather than relying on the
# platform's locale encoding (which can mis-decode or fail on non-ASCII text).
with open('aime.json', 'r', encoding='utf-8') as json_file:
    data = json.load(json_file)

# Accumulates one record per problem; filled by the generation loop.
results = []

tokenizer = AutoTokenizer.from_pretrained(model_path)

# 8-bit (LLM.int8()) quantization settings; only applied when `quant` is set.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,           # enable 8-bit quantization
    llm_int8_threshold=6.0,      # default threshold for LLM.int8()
    llm_int8_skip_modules=None,  # no modules excluded from quantization
)

# Both loading paths share device_map="auto"; the quantized path only adds the
# quantization config on top.
load_kwargs = {"device_map": "auto"}
if quant:
    load_kwargs["quantization_config"] = bnb_config
model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    
# Generation stops at the tokenizer's end-of-sequence token. The list is the
# same for every problem, so it is built once instead of per iteration.
terminators = [
    tokenizer.eos_token_id,
]

# For each problem: build the chat prompt, sample 5 completions, and store the
# stripped completions on the record under 'model_responses'.
for d in tqdm(data):
    prompt = f'''
        Solve the following math problem efficiently and clearly:

            - For simple problems (2 steps or fewer):
            Provide a concise solution with minimal explanation.

            - For complex problems (3 steps or more):
            Use this step-by-step format:

            ## Step 1: [Concise description]
            [Brief explanation and calculations]

            ## Step 2: [Concise description]
            [Brief explanation and calculations]

            ...

            Regardless of the approach, always conclude with:

            Therefore, the final answer is: $\\boxed{{answer}}$. I hope it is correct.

            Where [answer] is just the final number or expression that solves the problem.
        
        Problem: {d["Problem"]}
        '''

    messages = [{"role": "user", "content": prompt}]

    # Ask explicitly for a dict: this yields the attention mask alongside the
    # input ids (generate cannot infer it reliably when the pad token equals
    # the EOS token) and avoids depending on the library's default return type.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=8192,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.9,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=5,
    )

    # Each output row starts with the prompt tokens; slice them off so only
    # the newly generated continuation is decoded.
    generations = tokenizer.batch_decode(
        outputs[:, prompt_length:],
        skip_special_tokens=True,
    )
    d['model_responses'] = [g.strip() for g in generations]
    results.append(d)

# Persist every record (problem fields plus the sampled responses) as
# pretty-printed JSON, in a file named after the model alias.
output_path = f'5_samples_aime_gen_{model_name}.json'
with open(output_path, 'w') as out_file:
    json.dump(results, out_file, indent=4)
